Jetstream by default #118
Conversation
The implementation is slightly different, so a separate test is added.
Most tests work for both backends, except for the continuous batching one. This allows removing the old GPT2-based tests, which are quite slow and do not use any sharding or KV cache, so they are not really representative of the most relevant models on TGI.
There are now equivalent tests on the TinyLlama model that run faster and use the KV cache and sharding. The only test without an equivalent is the continuous batching one, but that test was not working for most other models, so I prefer to remove it anyway; having it pass was not representative of the current state.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Now that the engine is stable and tested, it is set as the default engine for TGI.
Force-pushed from 078dfa4 to 07f74ff
I think I got lost in your changes: can you summarize how tests are now supposed to work?
ids=["spaces", "chinese-utf8", "emojis"], | ||
) | ||
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text): | ||
if not jetstream_pt_available(): |
Note that you could have created a decorator.
I refactored the test to avoid repetitions.
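For reference, a minimal sketch of the kind of decorator suggested above, assuming `jetstream_pt_available` is importable from the project (the decorator name is hypothetical, not the PR's actual code):

import functools

import pytest

def requires_jetstream(test_fn):
    # Skip the wrapped test when Jetstream PyTorch is not available.
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        if not jetstream_pt_available():
            pytest.skip("Jetstream PyTorch is not available")
        return test_fn(*args, **kwargs)
    return wrapper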
    assert generations[0].tokens.texts == [" the"]


def test_prefill_truncate_jetstream():
I fail to see the difference between the two tests: I don't think it was required to add the 'jetstream' one
The two tests are identical in behaviour, but if Jetstream is loaded the other test will fail to run correctly because of a dependency incompatibility when using some PyTorch features (i.e. multiprocessing).
I just have two identical tests: one runs when Jetstream is enabled while the other is skipped, and when Jetstream is disabled it is the other way around.
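As a rough sketch of the pattern described here (the shared helper and skip messages are placeholders, not the PR's actual code):

import pytest

def test_prefill_truncate_jetstream():
    if not jetstream_pt_available():
        pytest.skip("Jetstream PyTorch is not enabled")
    _check_prefill_truncate()  # hypothetical shared helper

def test_prefill_truncate():
    if jetstream_pt_available():
        pytest.skip("Jetstream PyTorch is enabled, the torch_xla variant would not run correctly")
    _check_prefill_truncate()

This way the same check is exercised on whichever backend is active, and the other copy is reported as skipped rather than failing.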
They are not only identical in behaviour: this is the same test with two different names... What am I missing?
    _test_continuous_batching_two_requests(model_path)


"""NOTE: This test does not work on PyTorch/XLA, because of the way
You should adapt the test to make it actually useful for the XLA configuration.
In my tests, with BF16 and KV cache, I was not able to get this test working. I think there might be an issue in the way the KV cache is implemented, because this test succeeds on the Jetstream backend with BF16 on the same hardware. That is why I left the test there, as a reminder that this should be done later on, even though it does not really work as expected for now.
Force-pushed from 33fa858 to 5de0979
So far, filtering was done using the name of the test. Now selection is done using a custom marker, which allows clearer filtering.
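A minimal sketch of marker-based selection, assuming markers named `jetstream` and `torch_xla` (the exact marker names and registration used in the PR may differ):

# conftest.py
def pytest_configure(config):
    config.addinivalue_line("markers", "jetstream: test requires Jetstream PyTorch")
    config.addinivalue_line("markers", "torch_xla: test requires the torch_xla backend")

# in a test module
import pytest

@pytest.mark.jetstream
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
    ...

Tests can then be selected with `pytest -m jetstream` or excluded with `pytest -m "not jetstream"` instead of relying on test names.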
Force-pushed from 5de0979 to 6704a80
Nice! Thank you for this pull request.
# Skip tests that require torch xla but not jetstream
if "torch_xla" in marker_names and "jetstream" not in marker_names:
    if jetstream_pt_enabled:
        pytest.skip("Jetstream PyTorch must be disabled")
nit: I would find it clearer to say something like "Jetstream is enabled: xla test will fail".
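For context, a hedged sketch of how such a marker check could be wired into a conftest.py hook, using the wording suggested above; the hook choice and the way `jetstream_pt_enabled` is derived here are assumptions, not the PR's actual code:

# conftest.py (sketch)
import os

import pytest

def pytest_runtest_setup(item):
    marker_names = {marker.name for marker in item.iter_markers()}
    # Assumption: Jetstream is considered enabled unless JETSTREAM_PT_DISABLE=1 is set.
    jetstream_pt_enabled = os.environ.get("JETSTREAM_PT_DISABLE", "0") != "1"
    # Skip tests that require torch xla but not jetstream
    if "torch_xla" in marker_names and "jetstream" not in marker_names and jetstream_pt_enabled:
        pytest.skip("Jetstream is enabled: xla test will fail")
    # Skip jetstream tests when Jetstream PyTorch is disabled
    if "jetstream" in marker_names and not jetstream_pt_enabled:
        pytest.skip("Jetstream PyTorch is disabled")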
@dacorvo as discussed offline, the idea is to change the default backend of TPU TGI from torch xla to jetstream.
For some reason the env var was not carried over (though Jetstream was disabled anyway). Moving the variable to the command line invocation will remove a warning in the logs.
JETSTREAM_PT_DISABLE=1
Some test results change when operations are done in a slightly different way. This has now happened with the torch xla tests, producing different results on the CI. To avoid this, the tests now check that the obtained token and text are different from the ones obtained when running with greedy search.
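As a rough illustration of that check (the generate helper and its arguments are placeholders, not the PR's actual test code):

def test_sampling_differs_from_greedy():
    # Placeholder helper: generate() stands in for whatever the test uses to
    # produce the next token id and its decoded text for a fixed prompt.
    greedy_token, greedy_text = generate(do_sample=False)
    sampled_token, sampled_text = generate(do_sample=True)
    # Instead of asserting an exact expected value (fragile across backends and
    # op orderings), only check that sampling differs from greedy search.
    assert sampled_token != greedy_token
    assert sampled_text != greedy_text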
What does this PR do?
This makes all the changes needed to have the Jetstream PyTorch engine be the default backend for TGI on TPUs.
This backend is reliable and performant, and gives the best throughput on TGI.